{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建语音识别系统"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据集下载"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 下载数据集：[BZNSYP](https://www.data-baker.com/data/index/TNtts/)\n",
    "- 大约12小时\n",
    "- 来自同一个说话人\n",
    "- 本次实验我们采用pinyin作为我们的识别结果，来构建一个语音识别系统\n",
    "\n",
    "\n",
    "2. 构建数据集\n",
    "- 将数据集放到dataset目录下，dataset目录如下：\n",
    "\n",
    "    ```\n",
    "    dataset\n",
    "    ├── PhoneLabeling\n",
    "    │   ├── 000001.interval\n",
    "    ├── ProsodyLabeling\n",
    "    │   ├── 000001-010000.txt\n",
    "    ├── Wave\n",
    "    │   ├── 000001.wav\n",
    "    ```\n",
    "\n",
    "- 运行 `splitdata/split_data.py` 划分数据集，最后dataset目录下会多一个split目录\n",
    "\n",
    "    ```\n",
    "    dataset\n",
    "    ├── split\n",
    "    │   ├── train\n",
    "    │   │   ├── wav.scp\n",
    "    │   │   ├── pinyin\n",
    "    │   ├── dev\n",
    "    │   │   ├── wav.scp\n",
    "    │   │   ├── pinyin\n",
    "    │   ├── test\n",
    "    │   │   ├── wav.scp\n",
    "    │   │   ├── pinyin\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据提取"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据提取的框架已经构建好，位于 data/dataloader.py 的BZNSYP类中，需要完成：\n",
    "- 语音特征提取\n",
    "- 文本处理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 语音特征提取(`data.dataloader.extract_audio_features`函数)\n",
    "\n",
    "    -  可以把实验一完成的特征提取给放进去\n",
    "    -  要保证返回的结果为一个tensor，并且维度为(L,f)\n",
    "\n",
    "2. 文本处理\n",
    "    - 构建tokenizer：`tokenizer.tokenizer.Tokenizer`\n",
    "    - tokenizer需要做到：\n",
    "        - 将一段字符映射成token id\n",
    "        - 需要完成Tokenizer框架中的TODO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### tokenizer构建"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 构建字典\n",
    "    - 构建字典的文件已经写好，tokenizer/gen_vocab.py\n",
    "    - 可以修改一下对应的路径，vocab结果如下\n",
    "\n",
    "        ```\n",
    "        huang\n",
    "        cheng\n",
    "        lo\n",
    "        ...\n",
    "        ```\n",
    "\n",
    "2. 完成Tokenizer的TODO部分\n",
    "    - call函数\n",
    "    - decode函数\n",
    "    - 注意特殊字符\\<pad\\>,\\<unk\\>, \\<sos\\>, \\<eos\\>,\\<blk\\>, \" \""
   ]
  },
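  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two TODOs above can be sketched roughly as follows. This is a minimal sketch, not the course's `Tokenizer`: the class name `SketchTokenizer`, the in-memory vocabulary list, and the choice to give the special tokens the lowest ids are all assumptions made for illustration. Unknown tokens map to `<unk>`, and `decode` drops special tokens unless `ignore_special=False`.\n",
    "\n",
    "```python\n",
    "# Hypothetical sketch; the real tokenizer.tokenizer.Tokenizer may differ.\n",
    "SPECIAL_TOKENS = ['<pad>', '<unk>', '<sos>', '<eos>', '<blk>', ' ']\n",
    "\n",
    "\n",
    "class SketchTokenizer:\n",
    "    def __init__(self, vocab):\n",
    "        # Assumption: special tokens take the lowest ids, then the vocabulary.\n",
    "        self.id2token = SPECIAL_TOKENS + [t for t in vocab if t not in SPECIAL_TOKENS]\n",
    "        self.token2id = {token: i for i, token in enumerate(self.id2token)}\n",
    "        self.special_ids = set(range(len(SPECIAL_TOKENS)))\n",
    "\n",
    "    def __call__(self, tokens):\n",
    "        # Map each token to its id; unknown tokens fall back to <unk>.\n",
    "        unk_id = self.token2id['<unk>']\n",
    "        return [self.token2id.get(token, unk_id) for token in tokens]\n",
    "\n",
    "    def decode(self, ids, ignore_special=True):\n",
    "        # Map ids back to tokens, optionally dropping the special tokens.\n",
    "        return [\n",
    "            self.id2token[i]\n",
    "            for i in ids\n",
    "            if not (ignore_special and i in self.special_ids)\n",
    "        ]\n",
    "\n",
    "\n",
    "sketch = SketchTokenizer(['wo', 'men', 'cheng'])\n",
    "ids = sketch(['<sos>', 'A', 'wo', 'men', 'cheng', '<eos>'])\n",
    "print(ids)\n",
    "print(sketch.decode(ids))\n",
    "print(sketch.decode(ids, ignore_special=False))\n",
    "```\n",
    "\n",
    "In this sketch an out-of-vocabulary token such as `A` is encoded as `<unk>` and is dropped again when decoding with the default `ignore_special=True`."
   ]
  },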
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tokenizer.tokenizer import Tokenizer\n",
    "pinyin_list1 = ['wo', 'men', 'cheng', 'shi', 'de', 'fu', 'su', 'you', 'lai']\n",
    "pinyin_list2 = [\"A\", 'wo', 'men', 'cheng']\n",
    "tokenizer = Tokenizer(\"./tokenizer/vocab.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['wo', 'men', 'cheng', 'shi', 'de', 'fu', 'su', 'you', 'lai']\n"
     ]
    }
   ],
   "source": [
    "id_list1 = tokenizer(pinyin_list1)\n",
    "print(tokenizer.decode(id_list1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['wo', 'men', 'cheng']\n"
     ]
    }
   ],
   "source": [
    "id_list2 = tokenizer(pinyin_list2)\n",
    "print(tokenizer.decode(id_list2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def pre_emphasis(x: np.ndarray, alpha: float = 0.97) -> np.ndarray:\n",
    "    return np.append(x[0], x[1:] - alpha * x[:-1])\n",
    "\n",
    "def framing(x: np.ndarray, sr: int, frame_lenth: float = 0.025, frame_gap: float = 0.010) -> np.ndarray:\n",
    "    frame_len, frame_step = int(round(frame_lenth * sr)), int(round(frame_gap * sr))\n",
    "    signal_len = len(x)\n",
    "\n",
    "    if signal_len <= frame_len:\n",
    "        num_frames = 1\n",
    "    else:\n",
    "        num_frames = 1 + int(np.ceil((signal_len - frame_len) / frame_step))\n",
    "\n",
    "    pad_signal_length = (num_frames - 1) * frame_step + frame_len\n",
    "    amount_to_pad = pad_signal_length - signal_len\n",
    "    \n",
    "    pad_signal = np.pad(x, (0, max(0, amount_to_pad)), mode='constant', constant_values=0)\n",
    "    \n",
    "    frame_indices_offset = np.arange(frame_len)\n",
    "    frame_start_points = np.arange(num_frames) * frame_step\n",
    "    \n",
    "    indices = frame_start_points[:, np.newaxis] + frame_indices_offset[np.newaxis, :]\n",
    "    \n",
    "    frames = pad_signal[indices.astype(np.int32, copy=False)]\n",
    "    return frames\n",
    "\n",
    "def add_window(frame_sig: np.ndarray, sr: int, frame_len_s: float = 0.025) -> np.ndarray:\n",
    "    window = np.hamming(int(round(frame_len_s * sr)))\n",
    "    return frame_sig * window\n",
    "\n",
    "def stft(frame_sig: np.ndarray, nfft: int = 512) -> tuple[np.ndarray, np.ndarray]:\n",
    "    frame_spec = np.fft.rfft(frame_sig, n=nfft)\n",
    "    frame_mag = np.abs(frame_spec)\n",
    "    frame_pow = (frame_mag ** 2) / nfft\n",
    "    return frame_mag, frame_pow\n",
    "\n",
    "def get_filter_banks(sr, n_filters=40, nfft=512):\n",
    "    low_freq_mel = 0\n",
    "    high_freq_mel = 2595 * np.log10(1 + (sr / 2) / 700)\n",
    "    \n",
    "    mel_points = np.linspace(low_freq_mel, high_freq_mel, n_filters + 2)\n",
    "    hz_points = 700 * (10 ** (mel_points / 2595) - 1)\n",
    "    bins = np.floor((nfft + 1) * hz_points / sr).astype(int)\n",
    "    filter_banks = np.zeros((n_filters, nfft // 2 + 1))\n",
    "    fft_freqs = np.arange(nfft // 2 + 1)\n",
    "    \n",
    "    for i in range(n_filters):\n",
    "        left, center, right = bins[i:i+3]\n",
    "\n",
    "        left_mask = (left <= fft_freqs) & (fft_freqs < center)\n",
    "        if center != left:\n",
    "            filter_banks[i, left_mask] = (fft_freqs[left_mask] - left) / (center - left)\n",
    "\n",
    "        right_mask = (center <= fft_freqs) & (fft_freqs < right)\n",
    "        if right != center:\n",
    "            filter_banks[i, right_mask] = (right - fft_freqs[right_mask]) / (right - center)\n",
    "    \n",
    "    return filter_banks\n",
    "\n",
    "def get_fbank(frame_pow: np.ndarray, filter_banks: np.ndarray) -> np.ndarray:\n",
    "    return np.dot(frame_pow, filter_banks.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset\n",
    "from tokenizer.tokenizer import Tokenizer\n",
    "import torch\n",
    "import random\n",
    "import os\n",
    "from utils.utils import collate_with_PAD\n",
    "import librosa\n",
    "\n",
    "def extract_audio_features(wav_file:str)->torch.Tensor:\n",
    "\n",
    "    def calc_fbank(x: np.ndarray, sr: int = 16000, n_filters: int = 40, nfft: int = 512) -> np.ndarray:\n",
    "        x = pre_emphasis(x)\n",
    "        frames = framing(x, sr)\n",
    "        frames = add_window(frames, sr)\n",
    "        frame_mag, frame_pow = stft(frames, nfft)\n",
    "        filter_banks = get_filter_banks(sr, n_filters, nfft)\n",
    "        fbank = get_fbank(frame_pow, filter_banks)\n",
    "        return fbank\n",
    "\n",
    "    if not isinstance(wav_file, str):\n",
    "        raise TypeError(f\"Expected string for wav_file\")\n",
    "\n",
    "    y, sr = librosa.load(wav_file, sr=None)\n",
    "    fbank = calc_fbank(y, sr=sr, n_filters=80, nfft=512)\n",
    "\n",
    "    res = torch.from_numpy(fbank).float()\n",
    "\n",
    "    if not isinstance(res, torch.Tensor):\n",
    "        raise TypeError(\"Return value must be torch.Tensor\")\n",
    "    return res\n",
    "\n",
    "\n",
    "class BZNSYP(Dataset):\n",
    "    def __init__(self, wav_file, text_file, tokenizer):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.wav2path = {}\n",
    "        self.wav2text = {}\n",
    "        self.ids = []\n",
    "\n",
    "        with open(wav_file, \"r\", encoding=\"utf-8\") as f:\n",
    "            for line in f:\n",
    "                parts = line.strip().split(\"\\t\", 1)\n",
    "                if len(parts) == 2:\n",
    "                    id = parts[0]\n",
    "                    self.ids.append(id)\n",
    "                    path = \"./dataset/\" + parts[1]\n",
    "                    self.wav2path[id] = path\n",
    "                else:\n",
    "                    raise ValueError(f\"Invalid line format: {line}\")\n",
    "\n",
    "        with open(text_file, \"r\", encoding=\"utf-8\") as f:\n",
    "            for line in f:\n",
    "                parts = line.strip().split(\"\\t\", 1)\n",
    "                if len(parts) == 2:\n",
    "                    id = parts[0]\n",
    "                    pinyin_list = parts[1].split(\" \")\n",
    "                    self.wav2text[id] = self.tokenizer([\"<sos>\"]+pinyin_list+[\"<eos>\"])\n",
    "                else:\n",
    "                    raise ValueError(f\"Invalid line format: {line}\")\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.wav2path)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        id = list(self.wav2path.keys())[index]\n",
    "        path = self.wav2path[id]\n",
    "        text = self.wav2text[id]\n",
    "        return id, extract_audio_features(path), text\n",
    "    \n",
    "\n",
    "def get_dataloader(wav_file, text_file, batch_size, tokenizer, shuffle=True):\n",
    "    dataset = BZNSYP(wav_file, text_file, tokenizer)\n",
    "    dataloader = DataLoader(\n",
    "        dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=shuffle,\n",
    "        collate_fn=collate_with_PAD\n",
    "    )\n",
    "    return dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = Tokenizer()\n",
    "dataloader = get_dataloader(\"./dataset/split/train/wav.scp\", \"./dataset/split/train/pinyin\", 3, tokenizer, shuffle = False)\n",
    "input = None\n",
    "for batch in dataloader:\n",
    "    input = batch\n",
    "    break\n",
    "\n",
    "audio_lens = input[\"audio_lens\"]\n",
    "audios = input[\"audios\"]\n",
    "texts = input[\"texts\"]\n",
    "text_lens = input[\"text_lens\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['ids', 'audios', 'audio_lens', 'texts', 'text_lens'])\n"
     ]
    }
   ],
   "source": [
    "print(input.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([265, 285, 439], dtype=torch.int32)\n",
      "tensor([[0.0000e+00, 5.4336e-10, 0.0000e+00,  ..., 8.3642e-09, 3.3812e-08,\n",
      "         9.6364e-08],\n",
      "        [0.0000e+00, 7.9692e-10, 0.0000e+00,  ..., 9.7250e-09, 1.6037e-08,\n",
      "         4.6284e-08],\n",
      "        [0.0000e+00, 1.0147e-09, 0.0000e+00,  ..., 8.5312e-09, 3.1178e-08,\n",
      "         7.5750e-08],\n",
      "        ...,\n",
      "        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,\n",
      "         0.0000e+00],\n",
      "        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,\n",
      "         0.0000e+00],\n",
      "        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,\n",
      "         0.0000e+00]])\n"
     ]
    }
   ],
   "source": [
    "# audio\n",
    "print(audio_lens)\n",
    "print(audios[0, : , :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ka', 'e', 'er', 'pu', 'pei', 'wai', 'sun', 'wan', 'hua', 'ti']\n",
      "['<sos>', 'ka', 'e', 'er', 'pu', 'pei', 'wai', 'sun', 'wan', 'hua', 'ti', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['jia', 'yu', 'cun', 'yan', 'bie', 'zai', 'yong', 'bao', 'wo']\n",
      "['<sos>', 'jia', 'yu', 'cun', 'yan', 'bie', 'zai', 'yong', 'bao', 'wo', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
      "['bao', 'ma', 'pei', 'gua', 'bo', 'luo', 'an', 'diao', 'chan', 'yuan', 'zhen', 'dong', 'weng', 'ta']\n",
      "['<sos>', 'bao', 'ma', 'pei', 'gua', 'bo', 'luo', 'an', 'diao', 'chan', 'yuan', 'zhen', 'dong', 'weng', 'ta', '<eos>']\n"
     ]
    }
   ],
   "source": [
    "# text\n",
    "texts = texts.tolist()\n",
    "\n",
    "for text in texts:\n",
    "    print(tokenizer.decode(text))\n",
    "    print(tokenizer.decode(text, ignore_special=False))"
   ]
  },
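  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch printed above is padded to its longest sample: feature matrices with all-zero frames and token sequences with `<pad>`. `utils.utils.collate_with_PAD` itself is not shown in this notebook, so the following is only a minimal pure-Python sketch of that padding step. The function name `collate_with_pad_sketch`, the pad id `0`, and the use of plain lists instead of tensors are assumptions.\n",
    "\n",
    "```python\n",
    "# Hypothetical sketch; the real utils.utils.collate_with_PAD is not shown here\n",
    "# and presumably returns torch tensors rather than plain lists.\n",
    "PAD_ID = 0  # assumed id of '<pad>'\n",
    "\n",
    "\n",
    "def collate_with_pad_sketch(samples, pad_id=PAD_ID):\n",
    "    # Each sample is (utterance id, list of feature frames, list of token ids).\n",
    "    ids = [s[0] for s in samples]\n",
    "    audios = [s[1] for s in samples]\n",
    "    texts = [s[2] for s in samples]\n",
    "\n",
    "    # Record the true lengths before padding.\n",
    "    audio_lens = [len(a) for a in audios]\n",
    "    text_lens = [len(t) for t in texts]\n",
    "    feat_dim = len(audios[0][0])\n",
    "\n",
    "    # Pad every sequence on the right up to the longest one in the batch.\n",
    "    padded_audios = [\n",
    "        a + [[0.0] * feat_dim for _ in range(max(audio_lens) - len(a))]\n",
    "        for a in audios\n",
    "    ]\n",
    "    padded_texts = [t + [pad_id] * (max(text_lens) - len(t)) for t in texts]\n",
    "\n",
    "    return {\n",
    "        'ids': ids,\n",
    "        'audios': padded_audios,\n",
    "        'audio_lens': audio_lens,\n",
    "        'texts': padded_texts,\n",
    "        'text_lens': text_lens,\n",
    "    }\n",
    "\n",
    "\n",
    "batch = collate_with_pad_sketch([\n",
    "    ('000001', [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [2, 7, 8, 3]),\n",
    "    ('000002', [[0.7, 0.8]], [2, 9, 3]),\n",
    "])\n",
    "print(batch['audio_lens'], batch['text_lens'])\n",
    "print(batch['texts'])\n",
    "```\n",
    "\n",
    "Keeping `audio_lens` and `text_lens` next to the padded data lets later stages mask out the padding instead of treating it as real frames or tokens."
   ]
  },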
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
